package com.huan.springboot.websocket;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Repository;
import org.yeauty.pojo.Session;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Collectors;

/**
 * In-memory {@link WebsocketSessionRepository} that groups websocket sessions by the
 * {@code "userId"} attribute stored on each {@link Session}.
 *
 * <p>Thread-safety: the user map is a {@link ConcurrentHashMap} and every per-user list is a
 * {@link CopyOnWriteArrayList}, so sessions can be registered and read concurrently from
 * different threads. This class does not remove sessions itself.
 *
 * @author huan.fu
 * @date 2022/4/26 - 14:17
 */
@Repository
public class InMemoryWebsocketSessionRepository implements WebsocketSessionRepository {
    private static final Logger log = LoggerFactory.getLogger(InMemoryWebsocketSessionRepository.class);

    /** userId -> every session registered for that user. */
    private final ConcurrentHashMap<String, List<Session>> userSessions = new ConcurrentHashMap<>(256);

    /**
     * Registers a session under the {@code "userId"} attribute it carries.
     *
     * <p>{@code computeIfAbsent} makes "create the list if missing, then add" atomic per user, so
     * two sessions of the same user registering at the same time can no longer overwrite each
     * other's list (the former containsKey/get/put sequence could silently drop a session).
     *
     * @param session the session to register; must carry a non-null {@code "userId"} attribute
     * @throws NullPointerException if the session has no {@code "userId"} attribute
     */
    @Override
    public void addSession(Session session) {
        String userId = session.getAttribute("userId");
        log.info("从session中获取的userId为:[{}]", userId);

        // ConcurrentHashMap rejects null keys anyway; fail with a descriptive message instead.
        Objects.requireNonNull(userId, "session attribute 'userId' must not be null");
        // Copy-on-write list: safe for concurrent add and for readers iterating the returned list.
        userSessions.computeIfAbsent(userId, key -> new CopyOnWriteArrayList<>()).add(session);
    }

    /**
     * Returns a snapshot of the ids of all users that have at least one registered session list.
     *
     * @return a new mutable list of user ids; empty if nothing is registered
     */
    @Override
    public List<String> fetchAllUserIds() {
        return new ArrayList<>(userSessions.keySet());
    }

    /**
     * Looks up the sessions registered for a user.
     *
     * @param userId the user id to look up
     * @return the live internal list of that user's sessions, or {@code null} if the user has
     *         never registered a session
     */
    @Override
    public List<Session> findSession(String userId) {
        return userSessions.get(userId);
    }

    /**
     * Returns every registered session across all users, flattened into one list.
     *
     * @return a new list containing all sessions; empty if nothing is registered
     */
    @Override
    public List<Session> findAllSessions() {
        return userSessions.values()
                .stream()
                .flatMap(Collection::stream)
                .collect(Collectors.toList());
    }
}
